package command

import (
	"encoding/json"
	"fmt"
	_ "github.com/logoove/sqlite"
	"modernc.org/mathutil"
	"os"
	_ "path"
	"path/filepath"
	"redis-check/common"
	"redis-check/tool"
	"sort"
	"strings"
)

// StatOptions holds the command-line options of the RDB statistics command.
// The struct tags carry the flag names, defaults and help texts.
type StatOptions struct {
	common.BaseOptions
	// Input is the RDB path: a single file or a directory. Only files whose
	// extension is .rdb (case-insensitive) are processed by handleStat.
	Input     string `short:"i" long:"input" default:"" description:"RDB文件路径，指定文件或者文件夹，会自动读取以.rdb结尾的文件"`
	// Output is the directory the result files are written to. When it is
	// empty, outputDirAndName passes the directory of the RDB file to
	// common.DeftStr as the fallback.
	Output    string `short:"o" long:"output" default:"" description:"统计结果输出目录"`
	// MaxLevel is the maximum number of ":"-separated key levels kept in a
	// statistics key, e.g. in aaa:bbb:ccc the part ccc is level 3.
	MaxLevel  int32  `short:"m" long:"max-level" default:"1" min:"1" max:"5" description:"Key最大的层级，例如：aaa:bbb:ccc，ccc为第3级"`
	// MaxLength is the maximum length of a single key level; longer levels
	// are truncated and suffixed with "*".
	MaxLength int32  `short:"l" long:"max-length" default:"20" min:"1" max:"50" description:"单级Key的最大长度"`
	// Format selects the output file format: "text" or "json".
	Format    string `short:"f" long:"format" default:"text" choices:"text,json" description:"输出文件格式，可选text、json"`
	// Sort selects the result ordering: "key", "count", "bytes" or "elems".
	Sort      string `short:"s" long:"sort" default:"key" choices:"key,count,bytes,elems" description:"排序方式，可选key、count、bytes、elems"`
	// Prefix restricts the statistics to keys that start with this prefix.
	Prefix    string `short:"p" long:"prefix" default:"" description:"Key前缀，只统计此前缀的Key"`
	// Bytes is the big-key size threshold, given in MiB on the command line;
	// Execute multiplies it by 1024*1024 before use.
	Bytes     uint64 `long:"bytes" default:"100" description:"大Key字节数阀值，单位：M"`
	// Elems is the big-key element-count threshold, given in units of ten
	// thousand on the command line; Execute multiplies it by 10000 before use.
	Elems     uint64 `long:"elems" default:"50" description:"大Key元素个数阀值，单位：万"`
	// Merge combines the statistics of all RDB files into a single result.
	Merge     bool   `long:"merge" description:"是否合并多个RDB文件的统计结果"`
	// Key enables writing every key to a ".key" file.
	Key       bool   `long:"key" description:"开启输出所有的Key"`
	// Detail enables writing per-key details (key, type, element count and
	// byte size) to the ".key" file.
	Detail    bool   `long:"detail" description:"开启输出所有Key的明细，包含Key、数量、字节数、元素数量"`
}

// StatCommand implements the RDB statistics command. It accumulates per-key
// statistics while RDB entries are decoded and writes them out afterwards.
type StatCommand struct {
	// Options are the parsed command-line options.
	Options    *StatOptions
	// statItems maps a statistics key (see toStatKey) to its counters.
	statItems  map[string]*StatInfo
	// totalCount is the number of entries counted so far.
	totalCount uint64
	// totalBytes is the sum of entry.Bytes over all counted entries.
	totalBytes uint64
	// totalElems is the sum of entry.Elems over all counted entries.
	totalElems uint64
	// keyFile is the optional ".key" output file; nil when neither the Key
	// nor the Detail option is enabled.
	keyFile    *os.File
}

// StatInfo holds the aggregated counters of one statistics key.
type StatInfo struct {
	// Key is the statistics key produced by toStatKey.
	Key     string
	// Count is the number of entries aggregated under Key.
	Count   uint64
	// Bytes is the sum of entry.Bytes of those entries.
	Bytes   uint64
	// Elems is the sum of entry.Elems of those entries.
	Elems   uint64
	// Big is the number of entries that exceeded the Bytes or Elems threshold.
	Big     uint64
	// Percent is this key's share of the total bytes, in percent; it is
	// recomputed by toResultList.
	Percent float32
}

// StatHead describes one column of the text report built by resultToText.
type StatHead struct {
	// Name is the column title.
	Name   string
	// Len is the column width, including padding.
	Len    int
	// Format is the fmt format string used to left-align a cell to Len.
	Format string
}

// Execute runs the statistics command: it parses the options, scans the input
// path and, when Merge is enabled, writes one combined result at the end.
//
// The args parameter is not used; options are parsed from os.Args[2:]. An
// error is returned only when common.RunWithArgs reports that execution must
// not continue. Failures while scanning or saving are logged and Execute still
// returns nil.
func (p *StatCommand) Execute(args []string) error {
	proceed, err := common.RunWithArgs(p.Options, os.Args[2:])
	if !proceed {
		return err
	}

	opts := p.Options
	if opts.Input == "" {
		// No input given on the command line: ask for it interactively.
		opts.Input = common.ScanConf("请输入RDB路径", false)[0]
	}
	// Thresholds are entered in MiB and in units of ten thousand; convert
	// them to raw byte and element counts.
	opts.Bytes *= 1024 * 1024
	opts.Elems *= 10000

	common.Logger.Infof("开始统计，RDB路径：%v", opts.Input)
	err = common.ScanPath(common.AbsPath(opts.Input), p.handleStat)
	if opts.Merge && err == nil {
		// In merge mode the per-file save is skipped, so write once here.
		err = p.saveResult(common.AbsPath(opts.Output), "redis_stat", true)
	}
	if err != nil {
		common.Logger.Infof("统计异常：%v", err.Error())
	}
	common.Logger.Infof("--------------- 统计完成 ----------------")
	return nil
}

// outputDirAndName derives the output directory and the base name for the
// result files of the given RDB file. The directory is whatever
// common.DeftStr yields for the configured Output and the directory of the
// RDB file (presumably the latter is the fallback when Output is empty —
// confirm in common.DeftStr); the name is common.FileName applied to the
// file's base name.
func (p *StatCommand) outputDirAndName(file *os.File) (string, string) {
	srcDir, base := filepath.Split(file.Name())
	outDir := common.DeftStr(p.Options.Output, srcDir)
	resultName := common.FileName(base)
	return outDir, resultName
}

// handleStat processes a single file found while scanning the input path.
// Files whose extension is not .rdb (case-insensitive) are skipped. For RDB
// files it resets or keeps the accumulated statistics (see initStat),
// optionally opens the ".key" output file, decodes the RDB content through
// statEntry and finally saves the result for this file.
//
// It returns the error from creating the key file, from decoding, or from
// saving the result.
func (p *StatCommand) handleStat(file *os.File) error {
	rdbPath := file.Name()
	if ext := strings.ToLower(filepath.Ext(rdbPath)); ext != ".rdb" {
		return nil
	}

	p.initStat()
	common.Logger.Infof("解析RDB文件：%v", rdbPath)
	outDir, baseName := p.outputDirAndName(file)

	if wantKeys := p.Options.Key || p.Options.Detail; wantKeys {
		kf, err := p.initKeyFile(outDir, baseName)
		if err != nil {
			return err
		}
		p.keyFile = kf
		// The key file is only written while this RDB file is decoded.
		defer kf.Close()
	}

	if err := tool.DecodeRdbFile(file, p.statEntry); err != nil {
		common.Logger.Warnf("解析RDB文件异常：%v", err.Error())
		return err
	}
	return p.saveResult(outDir, baseName, false)
}

// initStat prepares the accumulators for the next RDB file. In merge mode the
// existing statistics are kept once they have been created; otherwise all
// counters, the statistics map and the key file reference are reset.
func (p *StatCommand) initStat() {
	if p.statItems == nil || !p.Options.Merge {
		p.statItems = make(map[string]*StatInfo)
		p.totalCount, p.totalBytes, p.totalElems = 0, 0, 0
		p.keyFile = nil
	}
}

// statEntry accounts for one decoded RDB entry. Entries that do not match the
// configured Prefix are ignored. Otherwise the entry is added to the counters
// of its statistics key and to the overall totals, and is passed on to
// writeKeyInfo.
func (p *StatCommand) statEntry(entry *tool.RdbEntry) {
	opts := p.Options
	if opts.Prefix != "" && !strings.HasPrefix(entry.Key, opts.Prefix) {
		return
	}

	key := p.toStatKey(entry)
	info := p.statItems[key]
	if info == nil {
		info = &StatInfo{Key: key}
		p.statItems[key] = info
	}

	// A threshold of 0 disables the corresponding big-key check.
	overBytes := opts.Bytes > 0 && entry.Bytes > opts.Bytes
	overElems := opts.Elems > 0 && entry.Elems > opts.Elems

	info.Count++
	info.Bytes += entry.Bytes
	info.Elems += entry.Elems
	if overBytes || overElems {
		info.Big++
	}

	p.totalCount++
	p.totalBytes += entry.Bytes
	p.totalElems += entry.Elems

	p.writeKeyInfo(entry)
}

// toStatKey builds the statistics key for an entry. The entry key is split on
// ":", only the first MaxLevel levels are kept, and every kept level longer
// than MaxLength bytes is shortened and suffixed with "*".
//
// Truncation never splits a multi-byte UTF-8 character: the cut is moved back
// to the start of the character that straddles the MaxLength boundary, so the
// result stays valid text for the text and JSON reports. Negative MaxLevel or
// MaxLength values are treated as 0 instead of causing a slice-bounds panic.
func (p *StatCommand) toStatKey(entry *tool.RdbEntry) string {
	maxLevel, maxLen := int(p.Options.MaxLevel), int(p.Options.MaxLength)
	if maxLevel < 0 {
		maxLevel = 0
	}
	if maxLen < 0 {
		maxLen = 0
	}

	levels := strings.Split(entry.Key, ":")
	// Drop surplus levels first so they are not truncated needlessly.
	if len(levels) > maxLevel {
		levels = levels[:maxLevel]
	}

	for i, item := range levels {
		if len(item) <= maxLen {
			continue
		}
		// Ranging over a string yields the byte offset of each character
		// start; keep the largest one that does not exceed maxLen.
		cut := 0
		for start := range item {
			if start > maxLen {
				break
			}
			cut = start
		}
		levels[i] = item[:cut] + "*"
	}
	return strings.Join(levels, ":")
}

// initKeyFile creates the "<name>.key" file in the output directory. When the
// Detail option is enabled a tab-separated header line (Key, Type, Elems,
// Bytes) is written first.
//
// It returns the open file, or the error from creating the file or writing
// the header. If writing the header fails the file is closed before
// returning, so no file handle is leaked.
func (p *StatCommand) initKeyFile(output string, name string) (*os.File, error) {
	keyFile, err := common.CreateFile(output, name+".key", false)
	if err != nil {
		return nil, err
	}
	if p.Options.Detail {
		if _, err := keyFile.WriteString("Key\tType\tElems\tBytes\n"); err != nil {
			// Best-effort cleanup; the write error is the one worth reporting.
			_ = keyFile.Close()
			return nil, err
		}
	}
	return keyFile, nil
}

// writeKeyInfo appends one line for the entry to the key file, if one is
// open. With the Detail option the line holds the tab-separated key, type,
// element count and byte size; otherwise it holds only the key.
func (p *StatCommand) writeKeyInfo(entry *tool.RdbEntry) {
	if p.keyFile == nil {
		return
	}
	line := entry.Key + "\n"
	if p.Options.Detail {
		line = fmt.Sprintf("%v\t%v\t%v\t%v\n", entry.Key, entry.Type, entry.Elems, entry.Bytes)
	}
	p.keyFile.WriteString(line)
}

// saveResult writes the accumulated statistics in the configured format.
// dir and name select the target directory and the file base name. In merge
// mode nothing is written until finish is true, so that a single combined
// result is produced after all files were processed.
func (p *StatCommand) saveResult(dir string, name string, finish bool) error {
	if kf := p.keyFile; kf != nil {
		common.Logger.Infof("Key输出目录：%v", kf.Name())
	}
	if !finish && p.Options.Merge {
		// Merge mode: keep accumulating until the final call.
		return nil
	}

	common.Logger.Infof("解析RDB文件完成，Key总数：%v，Value总数：%v，统计总数：%v，存储大小：%v", p.totalCount, p.totalElems, len(p.statItems), common.ShowBytes(p.totalBytes))

	switch p.Options.Format {
	case "json":
		return p.saveResultAsJson(dir, name)
	case "text":
		return p.saveResultAsText(dir, name)
	default:
		return p.saveResultAsLog()
	}
}

// saveResultAsText writes the text report to "<name>.txt" in the given
// directory and logs the resulting file name. It returns the error from
// creating or writing the file.
func (p *StatCommand) saveResultAsText(path string, name string) error {
	out, err := common.CreateFile(path, name+".txt", false)
	if err != nil {
		return err
	}
	defer out.Close()

	if _, writeErr := out.WriteString(p.resultToText()); writeErr != nil {
		return writeErr
	}
	common.Logger.Infof("保存统计结果：%v", out.Name())
	return nil
}

// saveResultAsJson writes the sorted statistics as JSON to "<name>.json" in
// the given directory and logs the resulting file name.
//
// It returns the error from creating the file or from encoding the result.
// The success message is logged only after the data has been written.
func (p *StatCommand) saveResultAsJson(path string, name string) error {
	file, err := common.CreateFile(path, name+".json", false)
	if err != nil {
		return err
	}
	defer file.Close()

	if err := json.NewEncoder(file).Encode(p.toResultList()); err != nil {
		return err
	}
	common.Logger.Infof("保存统计结果：%v", file.Name())
	return nil
}

// saveResultAsLog writes the text report to the logger instead of a file.
// It always returns nil.
func (p *StatCommand) saveResultAsLog() error {
	// Pass the report as an argument rather than as part of the format
	// string: keys may contain '%', which would otherwise be interpreted as
	// formatting verbs and corrupt the output.
	common.Logger.Infof("统计结果：\n%v", p.resultToText())
	return nil
}

// resultToText renders the sorted statistics as a left-aligned text table
// with the columns Key, Count, Bytes, Elems, Percent and Big. Each column is
// as wide as its longest cell or its title (measured with len, i.e. in
// bytes) plus two characters of padding. The text ends with a newline.
func (p *StatCommand) resultToText() string {
	heads := [...]*StatHead{{Name: "Key"}, {Name: "Count"}, {Name: "Bytes"}, {Name: "Elems"}, {Name: "Percent"}, {Name: "Big"}}
	rows := make([][len(heads)]string, len(p.statItems))

	// First pass: format every cell and track the widest cell per column.
	for r, info := range p.toResultList() {
		rows[r] = [...]string{
			info.Key,
			common.ToString(info.Count),
			common.ShowBytes(info.Bytes),
			common.ToString(info.Elems),
			fmt.Sprintf("%.6f", info.Percent),
			common.ToString(info.Big),
		}
		for c, cell := range rows[r] {
			heads[c].Len = mathutil.Max(heads[c].Len, len(cell))
		}
	}

	// Second pass: fix the column formats and emit the header line.
	var out strings.Builder
	for _, col := range heads {
		col.Len = mathutil.Max(col.Len, len(col.Name)) + 2
		col.Format = "%-" + common.ToString(col.Len) + "v"
		out.WriteString(fmt.Sprintf(col.Format, col.Name))
	}

	// Finally emit one line per statistics row.
	for _, cells := range rows {
		out.WriteString("\n")
		for c, cell := range cells {
			out.WriteString(fmt.Sprintf(heads[c].Format, cell))
		}
	}
	out.WriteString("\n")
	return out.String()
}

// toResultList returns all statistics entries as a slice sorted according to
// the Sort option: ascending by key for "key", otherwise descending by bytes,
// elems or (for any other value) count. Entries that compare equal on the
// chosen field are ordered by key, so the output does not depend on map
// iteration order.
//
// As a side effect the Percent field of every entry is recomputed as its
// share of the total bytes. When the total is 0 the share is set to 0 rather
// than NaN, which encoding/json cannot encode.
func (p *StatCommand) toResultList() []*StatInfo {
	list := make([]*StatInfo, 0, len(p.statItems))
	total := float64(p.totalBytes)
	for _, info := range p.statItems {
		info.Percent = 0
		if p.totalBytes > 0 {
			// Compute in float64 to avoid losing precision on large byte counts.
			info.Percent = float32(float64(info.Bytes) * 100 / total)
		}
		list = append(list, info)
	}

	sort.Slice(list, func(i, j int) bool {
		a, b := list[i], list[j]
		switch p.Options.Sort {
		case "key":
			return a.Key < b.Key
		case "bytes":
			if a.Bytes != b.Bytes {
				return a.Bytes > b.Bytes
			}
		case "elems":
			if a.Elems != b.Elems {
				return a.Elems > b.Elems
			}
		default:
			if a.Count != b.Count {
				return a.Count > b.Count
			}
		}
		// Tie-break on the key for a deterministic order.
		return a.Key < b.Key
	})
	return list
}
